from sklearn.tree import DecisionTreeRegressor
from sklearn import tree
import graphviz as gpv

class DecisionTree(DecisionTreeRegressor):
    """Convenience wrapper around scikit-learn's ``DecisionTreeRegressor``.

    Adds single-sample prediction (``forecast``) and PDF export of the fitted
    tree (``drawDecisionTree``).

    No ``__init__`` is defined on purpose. scikit-learn's ``get_params`` /
    ``set_params`` / ``clone`` introspect the constructor signature. An
    argument-less override would hide every hyper-parameter: ``get_params()``
    would return ``{}`` and ``set_params(max_depth=...)`` would raise.
    Inheriting the parent constructor keeps those working. ``DecisionTree()``
    still builds a tree with the library defaults, and keywords such as
    ``DecisionTree(max_depth=3)`` now work too.
    """

    def forecast(self, data: list[float]) -> float:
        """Predict the target value for a single sample.

        Args:
            data: Feature values of one sample, in the same column order that
                was used when the model was fitted.

        Returns:
            The predicted value as a plain Python ``float``.
        """
        # predict() takes a batch of samples, so wrap the single sample in a
        # list and unwrap the single prediction it returns.
        return float(self.predict([data])[0])

    def drawDecisionTree(self, pdf_name, feature_names=None):
        """Render the fitted tree to a PDF file with Graphviz.

        Args:
            pdf_name: Output path, given with or without a trailing ``.pdf``.
                Graphviz appends the format extension itself. A trailing
                ``.pdf`` is therefore stripped, so that ``"tree.pdf"`` yields
                ``tree.pdf`` rather than ``tree.pdf.pdf``.
            feature_names: Optional feature names, forwarded to
                ``sklearn.tree.export_graphviz`` for the split labels.

        Returns:
            The path of the rendered PDF, as returned by ``Source.render``.
        """
        dot_data = tree.export_graphviz(
            self, out_file=None, feature_names=feature_names
        )
        base = str(pdf_name)
        if base.endswith(".pdf"):
            base = base[: -len(".pdf")]
        # Graphviz writes the DOT source to `base` and the rendered PDF to
        # `base + ".pdf"`. Stripping the suffix above keeps the DOT text from
        # overwriting the file name the caller asked for.
        graph = gpv.Source(dot_data, format="pdf")
        return graph.render(base)
